# __author__ = 'tianfuzneg'
# !/usr/bin/python
# -*- coding:utf-8 -*-

########################################################################################
# 20240821
# from C:\Users\tianfu_zeng\OneDrive\PycharmProjects\HCC_ONT\ONT_somatic_3.0\SV_3.0_HBV_clipped-INS_2.0.py
########################################################################################
import os
import shlex
import shutil
from multiprocessing import Pool

samtools = "/data/fs01/biosoft/samtools-1.9/samtools"
clip_py = '/data/fs09/wangzf/nanopore/ztf/HCC/ONT/program/extract_clipped_seq_from_bam_by_ids_v3.py'
ins_py = "/data/fs09/wangzf/nanopore/ztf/HCC/ONT/program/extract_INS_seq.py"
blastn = "/data/fs01/wangzf/software/ncbi-blast-2.11.0+/bin/blastn"
blast_db = "/data/fs09/wangzf/nanopore/ztf/HCC/ref/HBV_ztf/BLAST/HBV_genome_21_newid"

# extract ins and clipped seq
# clip-INS seq re-alignment
def get_rnames(bam, txt):
    """Write the unique read names of a BAM file to a text file.

    Runs ``samtools view`` on ``bam``, keeps the first column of every
    record with ``awk``, and stores the sorted, de-duplicated values in
    ``txt`` (one per line).

    Args:
        bam: Path to the input BAM file.
        txt: Path of the read-name list to create (overwritten if present).

    Returns:
        None. The exit status of the shell pipeline is not inspected.
    """
    # Quote the caller-supplied paths so that spaces or shell metacharacters
    # in them cannot split or alter the command line. Ordinary paths are
    # passed through unchanged by shlex.quote.
    os.system("%s view %s | awk '{print $1}' | sort | uniq > %s"
              % (samtools, shlex.quote(str(bam)), shlex.quote(str(txt))))

def extract_ins_clipped_seq(rnames_bam, rnames_txt, clipped_fa, ins_fa):
    """Run the clipped-sequence and INS-sequence extraction scripts.

    Both helper scripts (``clip_py`` and ``ins_py``) are invoked through
    ``python`` with the same ``-b`` (BAM) and ``-n`` (read-name list)
    arguments and differ only in their ``-o`` output.

    Args:
        rnames_bam: Path to the BAM file passed as ``-b``.
        rnames_txt: Path to the read-name list passed as ``-n``.
        clipped_fa: Output path handed to the clipped-sequence script.
        ins_fa: Output path handed to the INS-sequence script.

    Returns:
        None. The exit status of either command is not inspected.
    """
    # Quote the caller-supplied paths once; spaces or shell metacharacters
    # would otherwise break the command line.
    bam_arg = shlex.quote(str(rnames_bam))
    names_arg = shlex.quote(str(rnames_txt))
    os.system("python %s -b %s -n %s -o %s"
              % (clip_py, bam_arg, names_arg, shlex.quote(str(clipped_fa))))
    os.system("python %s -b %s -n %s -o %s"
              % (ins_py, bam_arg, names_arg, shlex.quote(str(ins_fa))))

def run_blast(in_fa, blast_out):
    """Align a FASTA file against the HBV BLAST database with blastn.

    The command uses ``-task blastn``, an e-value cutoff of 1e-5, tabular
    output (``-outfmt 6``), custom scoring (reward 5, penalty -4, gap open 8,
    gap extend 6), ``-max_target_seqs 1`` and 4 BLAST threads.

    Args:
        in_fa: Path to the query FASTA file.
        blast_out: Path of the tabular BLAST result to write.

    Returns:
        None. The exit status of blastn is not inspected.
    """
    # Quote the caller-supplied paths so that spaces or shell metacharacters
    # cannot split or alter the command line.
    os.system("%s -task blastn -evalue 1e-5 -outfmt 6 -penalty -4 -reward 5 -gapopen 8 -gapextend 6 "
              "-max_target_seqs 1 -num_threads 4 "
              "-query %s -db %s -out %s"
              % (blastn, shlex.quote(str(in_fa)), blast_db, shlex.quote(str(blast_out))))

def split_fasta(fasta_file, threads, out_dir):
    """Split a FASTA file into at most ``threads`` chunk files.

    Records (a ``>`` header line plus the lines that follow it) are kept
    intact and written, in input order, to ``<out_dir>/tmp/chunk_<i>.fasta``.
    The records are spread as evenly as possible, so each parallel worker
    gets one sizeable chunk instead of one tiny file per record.

    Args:
        fasta_file: Path to the FASTA file to split.
        threads: Desired number of chunks; values below 1 are treated as 1.
        out_dir: Directory under which the ``tmp`` chunk directory is created.

    Returns:
        List of chunk file paths in creation order. Empty if ``fasta_file``
        contains no lines.
    """
    tmp_dir = os.path.join(out_dir, 'tmp')
    os.makedirs(tmp_dir, exist_ok=True)

    # First pass: count the records so the chunk size can be derived from
    # the requested number of chunks.
    with open(fasta_file, 'r') as f:
        n_records = sum(1 for line in f if line.startswith('>'))
    n_chunks = max(1, threads)
    # Ceiling division; never less than one record per chunk.
    records_per_chunk = max(1, -(-n_records // n_chunks))

    chunk_files = []
    chunk_f = None
    records_in_chunk = 0
    try:
        # Second pass: stream lines straight to the current chunk file so a
        # large chunk never has to be held in memory.
        with open(fasta_file, 'r') as f:
            for line in f:
                is_header = line.startswith('>')
                # Open a new chunk for the very first line, or at a record
                # boundary once the current chunk holds its share of records.
                if chunk_f is None or (is_header and records_in_chunk >= records_per_chunk):
                    if chunk_f is not None:
                        chunk_f.close()
                    chunk_file = f"{tmp_dir}/chunk_{len(chunk_files)}.fasta"
                    chunk_f = open(chunk_file, 'w')
                    chunk_files.append(chunk_file)
                    records_in_chunk = 0
                if is_header:
                    records_in_chunk += 1
                chunk_f.write(line)
    finally:
        if chunk_f is not None:
            chunk_f.close()
    return chunk_files

def blast_fasta_multithread(fasta_file, threads, out_dir, blast_out):
    """BLAST a FASTA file in parallel and merge the per-chunk results.

    The input is split with ``split_fasta`` into chunk files under
    ``<out_dir>/tmp``, ``run_blast`` is executed on every chunk in a pool of
    ``threads`` worker processes, the chunk results are concatenated into
    ``blast_out`` in chunk order, and the ``tmp`` directory is removed.

    Args:
        fasta_file: Path to the query FASTA file.
        threads: Number of worker processes in the pool.
        out_dir: Directory under which the temporary chunk directory lives.
        blast_out: Path of the merged BLAST result. It is always created and
            is empty when no chunk produced output.

    Returns:
        None.
    """
    tmp_dir = os.path.join(out_dir, 'tmp')
    # Step 1: Split the FASTA file into smaller chunks
    chunk_files = split_fasta(fasta_file, threads, out_dir)
    # Swap only the file extension. A plain str.replace('fasta', 'blast')
    # would also rewrite any 'fasta' in the directory part of the path and
    # send the result to a different, possibly non-existent, location.
    jobs = [(chunk_file, os.path.splitext(chunk_file)[0] + '.blast')
            for chunk_file in chunk_files]
    # Step 2: Run BLASTn on each chunk in parallel. starmap blocks until all
    # jobs finish and re-raises a worker exception instead of dropping it.
    pools = Pool(threads)
    try:
        pools.starmap(run_blast, jobs)
    finally:
        pools.close()
        pools.join()
    # Step 3: combine the chunk results in chunk order. Done in Python so
    # the merge does not depend on shell globbing or argument-length limits.
    with open(blast_out, 'wb') as merged:
        for _, chunk_blast in jobs:
            # A chunk whose BLAST run produced no file is simply skipped.
            if os.path.exists(chunk_blast):
                with open(chunk_blast, 'rb') as part:
                    shutil.copyfileobj(part, merged)
    shutil.rmtree(tmp_dir, ignore_errors=True)

# clip-INS seq re-alignment results summary (per human breakpoints)
# header: chrom, start, end, HBV_breakpoint, sampleid, read_name-type1
def get_loc_info(loc_info_x):
    """Turn a ``<chrom>_<position>`` token into a 1-bp interval.

    Only the first two underscore-separated fields are used, e.g.
    ``'chr6_70300829'`` becomes ``['chr6', '70300829', '70300830']``.

    Args:
        loc_info_x: Locus token such as ``'chr6_70300829'``.

    Returns:
        ``[chrom, start, end]`` as strings, where ``end`` is ``start + 1``.
    """
    fields = loc_info_x.split('_')
    chrom, start = fields[0], fields[1]
    return [chrom, start, str(int(start) + 1)]

def get_bp(blast_out, hm_loc_bed, hm_hbv_txt, type1, sampleid):
    """Summarise BLAST hits into human/HBV breakpoint tables.

    Each BLAST line is whitespace-split. The fields used are column 1 (query
    id), column 2 (subject name), column 4 (alignment length) and columns
    9-10 (the two subject coordinates). The query id is parsed as
    ``<read_name>:<f0>_<f1>_<query_len>_<locus...>``. The meaning of
    ``f0``/``f1`` is not visible here; they are ignored.

    A hit is kept only when ``alignment_length / query_len > 0.5``.

    Human loci:
      * ``type1 == "INS"``: the remaining underscore-separated fields of the
        query id are used as a single locus (presumably chrom, start, end).
      * otherwise: the remaining text is split on ``-``; ``N`` entries are
        kept as placeholders, entries containing ``HBV`` are dropped and the
        rest are converted with ``get_loc_info``.

    Human loci are paired with the HBV coordinates by position: the first
    slot with ``<subject>:<c9>_<c9+1>`` and the second with
    ``<subject>:<c10>_<c10+1>``. More than two slots is not supported.

    Args:
        blast_out: Path to the tabular BLAST result (must exist).
        hm_loc_bed: Output path, one line per human breakpoint:
            locus fields, paired HBV location, sample id, read name.
        hm_hbv_txt: Output path, one line per kept hit: read name,
            comma-joined ``chrom:start-end|hbv`` (or ``N``) entries, and the
            HBV span ``<subject>:<c9>_<c10>``.
        type1: ``"INS"`` or any other value (treated as clipped).
        sampleid: Sample identifier written to the BED output.

    Returns:
        None. Both output files are always created; they stay empty when
        ``blast_out`` is empty.
    """
    with open(hm_loc_bed, 'w') as out, open(hm_hbv_txt, 'w') as out1:
        # An empty BLAST table means no hits; leave both outputs empty.
        if not os.path.getsize(blast_out):
            return
        with open(blast_out, 'r') as f:
            for line in f:
                c = line.strip().split()
                if not c:
                    # Tolerate blank lines instead of failing on c[0].
                    continue
                c1 = c[0].split(':')
                read_name = c1[0]
                hbv_name = c[1]
                c2 = c1[1].split('_')
                query_len = int(c2[2])
                alig_len = float(c[3])
                # filter blast: alignment_length/query_length > 0.5
                if alig_len / query_len <= 0.5:
                    continue
                # get human loc
                if type1 == "INS":
                    hm_loc_list = [c2[3:]]
                else:
                    hm_loc_list = []
                    for loc_info in '_'.join(c2[3:]).split('-'):
                        if loc_info == 'N':
                            hm_loc_list.append('N')
                        elif 'HBV' not in loc_info:
                            hm_loc_list.append(get_loc_info(loc_info))
                # hbv_loc
                hbv_loc = hbv_name + ':' + '_'.join(c[8:10])
                hbv_loc_list = [
                    hbv_name + ':' + '_'.join([c[8], str(int(c[8]) + 1)]),
                    hbv_name + ':' + '_'.join([c[9], str(int(c[9]) + 1)]),
                ]
                out_list2_info = []
                # Pair by position. list.index() would return the first match
                # and pair two identical human loci with the same HBV end.
                for slot, hm_loc in enumerate(hm_loc_list):
                    if hm_loc == 'N':
                        out_list2_info.append('N')
                        continue
                    hbv_loc_pair = hbv_loc_list[slot]  # breakpoints paired info
                    out.write('\t'.join(hm_loc + [hbv_loc_pair, sampleid, read_name]) + '\n')
                    out_list2_info.append(
                        '%s|%s' % ("%s:%s-%s" % (hm_loc[0], hm_loc[1], hm_loc[2]), hbv_loc_pair))
                # paired txt, 1 event per line
                out1.write('\t'.join([read_name, ','.join(out_list2_info), hbv_loc]) + '\n')

def run(options):
    """Run the whole INS/clipped HBV breakpoint pipeline for one sample.

    Steps, in order:
      1. list the read names of the BAM (``get_rnames``);
      2. extract clipped and INS sequences (``extract_ins_clipped_seq``);
      3. for INS, then clipped: BLAST the FASTA in parallel
         (``blast_fasta_multithread``) and summarise the hits (``get_bp``).

    All outputs are written to ``options.out_dir`` and are prefixed with
    ``options.sampleid``.

    Args:
        options: Object with the attributes ``sampleid``, ``out_dir``,
            ``bam`` and ``threads`` (e.g. an ``argparse.Namespace``).

    Returns:
        None.
    """
    sampleid = options.sampleid
    out_dir = options.out_dir
    rnames_bam = options.bam
    threads = options.threads
    # get_rnames writes into out_dir through shell redirection, which fails
    # when the directory does not exist yet, so create it up front.
    os.makedirs(out_dir, exist_ok=True)
    # get read names
    rnames_txt = f"{out_dir}/{sampleid}_readsid.txt"
    get_rnames(rnames_bam, rnames_txt)
    # get ins and clipped fasta
    clipped_fa = f"{out_dir}/{sampleid}_clipped.fasta"
    ins_fa = f"{out_dir}/{sampleid}_INS.fasta"
    extract_ins_clipped_seq(rnames_bam, rnames_txt, clipped_fa, ins_fa)
    # blast and get breakpoints: INS first, then clipped
    for type1, fasta in (("INS", ins_fa), ("clipped", clipped_fa)):
        fa_blast = f"{out_dir}/{sampleid}_{type1}.blast"
        blast_fasta_multithread(fasta, threads, out_dir, fa_blast)
        hm_loc_bed = f"{out_dir}/{sampleid}_{type1}_breakpoint_hm_hbv.bed"
        hm_hbv_txt = f"{out_dir}/{sampleid}_{type1}_breakpoint_hm_hbv.txt"
        get_bp(fa_blast, hm_loc_bed, hm_hbv_txt, type1, sampleid)


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the pipeline.
    from argparse import ArgumentParser
    parser = ArgumentParser(description='Identify HBV integration events by INS-clipped reads')
    parser.add_argument('-b', '--bam', help='bam file', required=True)
    parser.add_argument('-s', '--sampleid', help='output file prefix', required=True)
    parser.add_argument('-o', '--out_dir', help='output directory', required=True)
    parser.add_argument('-t', '--threads', help='number of threads (default: 1)', type=int, default=1)
    options = parser.parse_args()
    # A worker pool needs at least one process. Reject bad values here
    # rather than failing after the upstream steps have already run.
    if options.threads < 1:
        parser.error('--threads must be a positive integer')
    run(options)

